#!/usr/bin/python3
# coding=utf-8

# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================
import os
import sys
import logging

import numpy as np
import tensorflow as tf

# Debug switch: when True, MatmulGenData.gen_golden_data_fp16 also dumps the
# generated inputs and the golden result as float32 text files (one value per
# line) next to the binary files.
IS_OUTPUT_TXT = False


class MatmulGenData:
    """Generate input files and golden (reference) results for a batched matmul case.

    The instance stores the problem description (sizes, transposition flags,
    bias flag, data-type string and layout strings); the ``gen_*`` methods
    write the corresponding binary files below a work directory.
    """

    # Largest finite float16 magnitude; golden results are clamped to +/- this value.
    FP16_MAX = 65504
    # Number of rows grouped into one block by nd_to_nz.
    NZ_BLOCK_ROWS = 16
    # Supported ``data_type_str`` values mapped to the dtype of the golden output
    # (the inputs x1/x2 are always float16).
    _DST_TYPES = {
        "float16_float32": np.float32,
        "float16_float16": np.float16,
    }

    def __init__(self, m, n, k, b, is_trans_a, is_trans_b, is_bias, data_type_str,
                 a_format="ND", b_format="ND", c_format="ND", is_channel_split=False):
        """Store the matmul case description.

        Args:
            m, n, k: Matrix sizes; the result of each batch is ``m x n``.
            b: Batch count (leading dimension of x1, x2 and the result).
            is_trans_a: When true, x1 is generated as ``[b, k, m]`` and transposed in the matmul.
            is_trans_b: When true, x2 is generated as ``[b, n, k]`` and transposed in the matmul.
            is_bias: When true, a ``[1, n]`` bias is generated and added to the result.
            data_type_str: ``"float16_float32"`` or ``"float16_float16"``.
            a_format, b_format, c_format: ``"NZ"`` requests the blocked layout produced
                by ``nd_to_nz`` for x1, x2 and the result respectively; any other
                value leaves the data as generated.
            is_channel_split: Selects a ``c0size`` of 8 instead of 16 for ``nd_to_nz``.
        """
        self.m = m
        self.n = n
        self.k = k
        self.b = b
        self.is_trans_a = is_trans_a
        self.is_trans_b = is_trans_b
        self.is_bias = is_bias
        self.data_type_str = data_type_str
        self.a_format = a_format
        self.b_format = b_format
        self.c_format = c_format
        self.is_channel_split = is_channel_split

    @staticmethod
    def due_overflow(data):
        """Clamp ``data`` element-wise to the finite float16 range [-65504, 65504]."""
        return np.clip(data, -MatmulGenData.FP16_MAX, MatmulGenData.FP16_MAX)

    @staticmethod
    def nd_to_nz(matrix, shape, data_type, c0size):
        """Rearrange a row-major matrix into the blocked layout used for "NZ" data.

        The matrix is viewed as ``(rows/16, 16, cols/c0size, c0size)`` and the
        axes are permuted to ``(cols/c0size, rows/16, 16, c0size)``.

        Args:
            matrix: Array holding exactly ``shape[1] * shape[2]`` elements
                (the reshape fails otherwise, e.g. for a batch dimension > 1).
            shape: ``[batch, rows, cols]``; only ``rows`` and ``cols`` are used.
            data_type: dtype of the returned array.
            c0size: Column count of one block.

        Returns:
            A 4-D array of dtype ``data_type`` with shape
            ``(cols / c0size, rows / 16, 16, c0size)``.

        Raises:
            ValueError: If ``rows`` is not a multiple of 16, ``cols`` is not a
                multiple of ``c0size``, or the element count does not match.
        """
        rows, cols = shape[1], shape[2]
        block_rows = MatmulGenData.NZ_BLOCK_ROWS
        if rows % block_rows != 0 or cols % c0size != 0:
            raise ValueError(
                "cannot convert shape %s to NZ: rows must be a multiple of %d and "
                "cols a multiple of %d" % (list(shape), block_rows, c0size))
        # Integer division keeps the block counts exact (no float round trip).
        tiled = matrix.reshape((rows // block_rows, block_rows, cols // c0size, c0size))
        return tiled.transpose(2, 0, 1, 3).astype(data_type)

    def tf_matmul(self, x1_gm_fp32, x2_gm_fp32, bias_gm_fp32=None):
        """Compute the float32 reference result with TensorFlow.

        Args:
            x1_gm_fp32: Left operand as a float32 array.
            x2_gm_fp32: Right operand as a float32 array.
            bias_gm_fp32: Bias as a float32 array; required when ``self.is_bias`` is set.

        Returns:
            ``matmul(x1, x2)`` (honouring the transposition flags) plus the bias
            when enabled, clamped to the finite float16 range.

        Raises:
            ValueError: If ``self.is_bias`` is set but no bias array was passed.
        """
        if self.is_bias and bias_gm_fp32 is None:
            raise ValueError("bias_gm_fp32 is required when is_bias is set")

        tf.compat.v1.disable_eager_execution()
        # Build the ops in a private graph so repeated calls do not keep adding
        # placeholders and matmul nodes to the process-wide default graph.
        graph = tf.Graph()
        with graph.as_default():
            x1 = tf.compat.v1.placeholder(np.float32, shape=x1_gm_fp32.shape)
            x2 = tf.compat.v1.placeholder(np.float32, shape=x2_gm_fp32.shape)
            res_tf = tf.matmul(x1, x2, transpose_a=self.is_trans_a, transpose_b=self.is_trans_b)
            feed_dict = {x1: x1_gm_fp32, x2: x2_gm_fp32}
            if self.is_bias:
                bias = tf.compat.v1.placeholder(np.float32, shape=bias_gm_fp32.shape)
                res_tf = tf.add(res_tf, bias)
                feed_dict[bias] = bias_gm_fp32

        with tf.compat.v1.Session(graph=graph) as sess:
            result = sess.run(res_tf, feed_dict=feed_dict)
        return MatmulGenData.due_overflow(result)

    def gen_golden_data_fp16(self, work_dir, dst_type=np.float32):
        """Generate random float16 inputs and the golden result, and write them to disk.

        Files written below ``work_dir`` (the ``input`` and ``output``
        directories are created when missing):
            ``input/x1_gm.bin``, ``input/x2_gm.bin``, ``output/golden.bin`` and,
            when bias is enabled, ``input/bias_gm.bin``. When the module-level
            ``IS_OUTPUT_TXT`` flag is true, float32 text dumps with the same
            base names are written as well.

        Args:
            work_dir: Directory that holds the ``input`` and ``output`` folders.
            dst_type: dtype of the golden output and of the bias.

        Returns:
            0 on success.
        """
        src_type = np.float16
        c0size = 8 if self.is_channel_split else 16

        x1_shape = [self.b, self.k, self.m] if self.is_trans_a else [self.b, self.m, self.k]
        x2_shape = [self.b, self.n, self.k] if self.is_trans_b else [self.b, self.k, self.n]
        y_shape = [self.b, self.m, self.n]

        # Inputs are rounded to float16 first so the float32 reference sees exactly
        # the values that end up in the input files.
        x1_gm = np.random.uniform(-1, 1, x1_shape).astype(src_type)
        x1_gm_fp32 = x1_gm.astype(np.float32)
        x2_gm = np.random.uniform(-1, 1, x2_shape).astype(src_type)
        x2_gm_fp32 = x2_gm.astype(np.float32)

        bias_gm = None
        bias_gm_fp32 = None
        if self.is_bias:
            bias_gm = np.random.uniform(-1, 1, [1, self.n]).astype(dst_type)
            bias_gm_fp32 = bias_gm.astype(np.float32)

        y_gm_fp32 = self.tf_matmul(x1_gm_fp32, x2_gm_fp32, bias_gm_fp32)
        y_gm = y_gm_fp32.astype(dst_type)

        if self.a_format == "NZ":
            x1_gm = MatmulGenData.nd_to_nz(x1_gm, x1_shape, src_type, c0size)
        if self.b_format == "NZ":
            x2_gm = MatmulGenData.nd_to_nz(x2_gm, x2_shape, src_type, c0size)
        if self.c_format == "NZ":
            y_gm = MatmulGenData.nd_to_nz(y_gm, y_shape, dst_type, c0size)

        input_dir = os.path.join(work_dir, "input")
        output_dir = os.path.join(work_dir, "output")
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        x1_gm.tofile(os.path.join(input_dir, "x1_gm.bin"))
        x2_gm.tofile(os.path.join(input_dir, "x2_gm.bin"))
        y_gm.tofile(os.path.join(output_dir, "golden.bin"))
        if self.is_bias:
            bias_gm.tofile(os.path.join(input_dir, "bias_gm.bin"))

        if IS_OUTPUT_TXT:
            # Text dumps always hold the float32, non-NZ values.
            np.savetxt(os.path.join(input_dir, "x1_gm.txt"), x1_gm_fp32.flatten(), fmt='%f', newline='\n')
            np.savetxt(os.path.join(input_dir, "x2_gm.txt"), x2_gm_fp32.flatten(), fmt='%f', newline='\n')
            np.savetxt(os.path.join(output_dir, "golden.txt"),
                       y_gm_fp32.astype(np.float32).flatten(), fmt='%f', newline='\n')
            if self.is_bias:
                np.savetxt(os.path.join(input_dir, "bias_gm.txt"), bias_gm_fp32.flatten(),
                           fmt='%f', newline='\n')
        return 0

    def gen_golden_data(self, work_dir):
        """Generate inputs and golden data for the configured ``data_type_str``.

        Args:
            work_dir: Directory that holds the ``input`` and ``output`` folders.

        Returns:
            The return code of ``gen_golden_data_fp16`` (0 on success), or -1
            when ``data_type_str`` is not supported.
        """
        dst_type = self._DST_TYPES.get(self.data_type_str)
        if dst_type is None:
            # Logged at ERROR level so the message is visible with the default
            # logging configuration (INFO records are dropped by default).
            logging.error("[ERROR] can't support data type %s", self.data_type_str)
            return -1
        return self.gen_golden_data_fp16(work_dir, dst_type)

    def gen_fake_golden_data(self, work_dir):
        """Create correctly sized placeholder input files without computing any data.

        Writes ``input/x1_gm.bin``, ``input/x2_gm.bin`` and, when bias is
        enabled, ``input/bias_gm.bin`` below ``work_dir`` (the ``input``
        directory is created when missing). Each file is extended to the byte
        size of the corresponding real input; on most platforms the extended
        area is zero-filled.

        Args:
            work_dir: Directory that holds the ``input`` folder.
        """
        data_type_bytes_ab = 2  # float16
        # gen_golden_data_fp16 writes the bias in the output dtype, which is float16
        # for "float16_float16" and float32 otherwise; keep the file size in step.
        data_type_bytes_bias = 2 if self.data_type_str == "float16_float16" else 4

        file_sizes = {
            "x1_gm.bin": self.b * self.m * self.k * data_type_bytes_ab,
            "x2_gm.bin": self.b * self.k * self.n * data_type_bytes_ab,
        }
        if self.is_bias:
            file_sizes["bias_gm.bin"] = 1 * self.n * data_type_bytes_bias

        input_dir = os.path.join(work_dir, "input")
        os.makedirs(input_dir, exist_ok=True)
        for file_name, file_byte in file_sizes.items():
            with open(os.path.join(input_dir, file_name), 'wb') as file:
                file.truncate(file_byte)
